import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
pd.set_option('future.no_silent_downcasting', True)
pd.options.mode.chained_assignment = None 

# Per-device benchmark timings. The frames are stacked below, so they are
# assumed to share columns (at least 'Model', 'Hardware' and 'DPS' are used).
mps = pd.read_csv('./data/mps_timings.csv')
cuda = pd.read_csv('./data/cuda_timings.csv')
cpu = pd.read_csv('./data/cpu_timings.csv')
t4 = pd.read_csv('./data/t4_timings.csv')

# ignore_index yields a fresh 0..n-1 index (same as concat + reset_index).
df = pd.concat([mps, cuda, cpu, t4], ignore_index=True)

# Assign the result back rather than calling replace(..., inplace=True) on
# df['Model']: under pandas Copy-on-Write that chained in-place form only
# modifies a temporary and leaves df unchanged.
df['Model'] = df['Model'].replace({
    'Political_DEBATE_DeBERTa_base_v1.1': 'DEBATE Base',
    'Political_DEBATE_DeBERTa_large_v1.1': 'DEBATE Large',
    'Meta-Llama-3.1-8B-Instruct': 'Llama 3.1 8B',
})

# Human-readable device names, applied only to the Hardware column.
# 't4' maps to 'T4' to match the hardware labels used when plotting.
df['Hardware'] = df['Hardware'].replace({
    'mps': 'Apple M3',
    'cuda': 'RTX 3090',
    'cpu': 'Ryzen 9900x',
    't4': 'T4',
})

# There is no CPU benchmark for Llama 3.1 8B; add a zero-filled row so the
# model/hardware combination still exists. Building the row by column name
# avoids depending on the CSVs' column order or count.
zero_row = dict.fromkeys(df.columns, 0)
zero_row.update({'Model': 'Llama 3.1 8B', 'Hardware': 'Ryzen 9900x'})
df = pd.concat([df, pd.DataFrame([zero_row])], ignore_index=True)

############
## Figure 6
############
# Grouped bar chart of documents-per-second (DPS) for each model, with one
# bar per hardware platform, saved to ./figures/figure_6.png.
sns.set_palette("colorblind")

hardwareorder = ["RTX 3090", "Apple M3", "T4", "Ryzen 9900x"]

models = ['DEBATE Base', 'DEBATE Large', 'Llama 3.1 8B']
# .copy() gives an independent frame, so the column assignment below modifies
# `data` itself rather than a possible view of df.
data = df[df['Model'].isin(models)].copy()
# An ordered categorical fixes the hue level (bar) order within each group.
data['Hardware'] = pd.Categorical(data['Hardware'], categories=hardwareorder, ordered=True)

# Create a single figure for the grouped bar plot
plt.figure(figsize=(12, 6))

# Plot the bars
barplot = sns.barplot(x='Model', y='DPS', hue='Hardware', data=data)

# Add DPS numbers above the bars
for bar in barplot.patches:
    bar_height = bar.get_height()  # Height of the bar (DPS value)
    bar_x = bar.get_x() + bar.get_width() / 2  # Horizontal centre of the bar
    # Skip NaN heights (missing combinations) and zero heights (e.g. a
    # zero-filled placeholder row), which would only show a stray "0".
    if not pd.isna(bar_height) and bar_height > 0:
        barplot.annotate(
            f'{bar_height:.0f}',  # Format the number as an integer
            (bar_x, bar_height),  # Position the text at the top of the bar
            ha='center',  # Center the text horizontally
            va='bottom',  # Place the text above the bar
            fontsize=16,
            fontweight='bold',
        )

# Remove the X-axis label
plt.xlabel('')

# Make X-axis tick labels bold
plt.xticks(fontweight='bold', fontsize=20)

# Make Y-axis label bold
plt.ylabel('Documents Per Second', fontweight='bold', fontsize=24)

# Remove the legend title
plt.legend(title='')

# Remove top and right spines
sns.despine()
plt.grid(False)

# Save the figure
plt.tight_layout()
plt.savefig('./figures/figure_6.png', dpi=300)